{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fast_bert.data_abs import BertAbsDataBunch\n",
    "from fast_bert.learner_abs import BertAbsLearner\n",
    "from box import Box\n",
    "import logging\n",
    "import torch\n",
    "from pathlib import Path\n",
    "from transformers import BertTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tokenizers import (ByteLevelBPETokenizer,\n",
    "                            BPETokenizer,\n",
    "                            SentencePieceBPETokenizer,\n",
    "                            BertWordPieceTokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root folder of the summarisation workspace, given relative to the working directory.\n",
    "PATH = Path(\"../../summarisation/\")\n",
    "\n",
    "# DATA_PATH is passed as `data_dir` to the data bunches; MODEL_PATH holds vocab.txt\n",
    "# and is passed to BertAbsLearner.from_pretrained_model.\n",
    "DATA_PATH = PATH.joinpath('data')\n",
    "MODEL_PATH = PATH.joinpath('model')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root logger (no name given); it is handed to the learners through their `logger` argument.\n",
    "logger = logging.getLogger(name=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run settings wrapped in a Box, so entries are read and set as attributes\n",
    "# elsewhere in the notebook (e.g. args.model_name, args.multi_gpu).\n",
    "args = Box(dict(\n",
    "    max_seq_length=512,\n",
    "    batch_size=8,\n",
    "    learning_rate=5e-3,\n",
    "    num_train_epochs=6,\n",
    "    fp16=True,\n",
    "    model_name='bertabs-finetuned-cnndm',\n",
    "    model_type='bert',\n",
    "))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use CUDA when at least one GPU is visible; otherwise fall back to the CPU.\n",
    "device = torch.device('cuda' if torch.cuda.device_count() else 'cpu')\n",
    "\n",
    "# Record on the settings object whether more than one GPU is visible.\n",
    "args.multi_gpu = torch.cuda.device_count() > 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# BertWordPieceTokenizer built from the vocab.txt file under MODEL_PATH, with lowercasing enabled.\n",
    "tokenizer = BertWordPieceTokenizer(\n",
    "    str(MODEL_PATH.joinpath('vocab.txt')),\n",
    "    lowercase=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data bunch that receives the `tokenizer` object (a BertWordPieceTokenizer).\n",
    "databunch = BertAbsDataBunch(\n",
    "    data_dir=DATA_PATH,\n",
    "    tokenizer=tokenizer,\n",
    "    device=device,\n",
    ")\n",
    "\n",
    "# Same data directory and device, but the tokenizer argument is the model-name string from args.\n",
    "databunch_old = BertAbsDataBunch(\n",
    "    data_dir=DATA_PATH,\n",
    "    tokenizer=args.model_name,\n",
    "    device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data bunch built with the model-name tokenizer; the sample passed to predict_batch\n",
    "# is read from its test dataset.\n",
    "# NOTE(review): the arguments match those of `databunch_old` -- confirm a separate instance is needed.\n",
    "databunch_with_data = BertAbsDataBunch(\n",
    "    data_dir=DATA_PATH,\n",
    "    tokenizer=args.model_name,\n",
    "    device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data bunch built with the `tokenizer` object (a BertWordPieceTokenizer).\n",
    "# NOTE(review): the arguments match those of `databunch` -- confirm a separate instance is needed.\n",
    "databunch_with_data_new_tokeniser = BertAbsDataBunch(\n",
    "    data_dir=DATA_PATH,\n",
    "    tokenizer=tokenizer,\n",
    "    device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Learner created via from_pretrained_model on `databunch`, the bunch that uses\n",
    "# the BertWordPieceTokenizer object.\n",
    "learner = BertAbsLearner.from_pretrained_model(\n",
    "    databunch,\n",
    "    MODEL_PATH,\n",
    "    device,\n",
    "    logger=logger,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Take element 1 of the first item in the test dataset; this is the input used by\n",
    "# both predict_batch timing cells.\n",
    "texts = (\n",
    "    databunch_with_data\n",
    "    .test_dl\n",
    "    .dataset[0][1]\n",
    ")\n",
    "\n",
    "# Display the selected sample.\n",
    "texts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%timeit\n",
    "# Time predict_batch on `texts` for the learner whose data bunch uses the BertWordPieceTokenizer object.\n",
    "learner.predict_batch(texts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Learner created via from_pretrained_model on `databunch_old`, the bunch whose\n",
    "# tokenizer argument is the model-name string.\n",
    "learner_old = BertAbsLearner.from_pretrained_model(\n",
    "    databunch_old,\n",
    "    MODEL_PATH,\n",
    "    device,\n",
    "    logger=logger,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%timeit\n",
    "# Time predict_batch on the same `texts` for the learner whose data bunch was given the model-name string.\n",
    "learner_old.predict_batch(texts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
